import torch
import numpy as np
import copy


@torch.no_grad()
def generate_encoder_attention_maps(model, image, device: str = 'cuda'):
    """Run the encoder without masking and collect every block's attention map.

    Each encoder block's self-attention is recomputed step by step (instead of
    calling ``blk(x)``) so that the post-softmax attention weights can be
    captured before being applied.

    Args:
        model: MAE-style model exposing ``patch_embed``, ``pos_embed``,
            ``cls_token``, ``convert_masking`` and ``blocks`` (each block with
            ``norm1``, ``attn`` (``qkv``, ``num_heads``, ``scale``,
            ``attn_drop``, ``proj``, ``proj_drop``), ``norm2`` and ``mlp``).
        image: image tensor handed to ``model.patch_embed`` after being moved
            to ``device``. No batch axis is added here, so it is presumably
            already batched (N, C, H, W) — TODO confirm with callers.
        device: device to run on.

    Returns:
        List with one numpy array per encoder block, each of shape
        (B, num_heads, N, N): the softmaxed attention weights, where N is the
        token count including the prepended cls token.
    """
    x = image.to(device)

    # embed patches
    x = model.patch_embed(x)

    # add pos embed w/o cls token
    x = x + model.pos_embed[:, 1:, :]

    # Identity patch order sized from the embedded tokens instead of a
    # hard-coded 196, so other image / patch sizes work too. Together with
    # mask_ratio=0 this is meant to keep every patch (convert_masking lives on
    # the model; its exact behavior is not visible here).
    ids_shuffle = np.arange(x.shape[1])[None, :]
    x, _, _ = model.convert_masking(x, ids_shuffle, mask_ratio=0)

    # append cls token (with its positional embedding)
    cls_token = model.cls_token + model.pos_embed[:, :1, :]
    cls_tokens = cls_token.expand(x.shape[0], -1, -1)
    x = torch.cat((cls_tokens, x), dim=1)

    attns = []
    # Apply the Transformer blocks with the attention written out explicitly.
    for blk in model.blocks:
        x_temp = blk.norm1(x)
        B, N, C = x_temp.shape
        qkv = blk.attn.qkv(x_temp).reshape(
            B, N, 3, blk.attn.num_heads, C // blk.attn.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * blk.attn.scale
        attn = attn.softmax(dim=-1)
        # attn: (B, num_heads, N, N) — captured after softmax, before dropout.
        attns.append(attn.detach().cpu().numpy())
        attn = blk.attn.attn_drop(attn)

        x_temp = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x_temp = blk.attn.proj(x_temp)
        x_temp = blk.attn.proj_drop(x_temp)
        # Residual connections around attention and MLP.
        x = x + x_temp
        x = x + blk.mlp(blk.norm2(x))
    return attns


def _mae_keep_shuffled_subset(latent, ids_shuffle, len_keep):
    """Keep the tokens listed first in ``ids_shuffle``; also build the binary mask.

    Returns ``(kept_tokens, mask, ids_restore)``:
      * ``kept_tokens`` — ``latent`` gathered at ``ids_shuffle[:, :len_keep]``;
      * ``mask`` — (batch, length) in original patch order, 0 = kept, 1 = removed;
      * ``ids_restore`` — ``argsort(ids_shuffle)``, i.e. the inverse
        permutation when ``ids_shuffle`` is a permutation.
    """
    batch, length, dim = latent.shape
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    ids_keep = ids_shuffle[:, :len_keep]
    kept = torch.gather(latent, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))

    # Mask is built in shuffled order (first len_keep kept) and then unshuffled.
    mask = torch.ones([batch, length], device=latent.device)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, dim=1, index=ids_restore)
    return kept, mask, ids_restore


def _mae_unrolled_decoder_block(blk, x):
    """Apply one decoder block to ``x`` with its self-attention written out step by step."""
    h = blk.norm1(x)
    B, N, C = h.shape
    qkv = blk.attn.qkv(h).reshape(
        B, N, 3, blk.attn.num_heads, C // blk.attn.num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

    attn = (q @ k.transpose(-2, -1)) * blk.attn.scale
    # attn: (B, num_heads, N, N). This is the hook point for inspecting or
    # modifying the decoder attention.
    attn = attn.softmax(dim=-1)
    # Attention dropout is applied here just as the encoder attention unroll in
    # generate_encoder_attention_maps does; for dropout modules it is a no-op
    # in eval mode.
    attn = blk.attn.attn_drop(attn)

    h = (attn @ v).transpose(1, 2).reshape(B, N, C)
    h = blk.attn.proj_drop(blk.attn.proj(h))
    # Residual connections around attention and MLP, as in the original block.
    x = x + blk.drop_path(h)
    return x + blk.drop_path(blk.mlp(blk.norm2(x)))


@torch.no_grad()
def generate_image(orig_image, model, ids_shuffle, len_keep: int, device: str = 'cpu'):
    """Reconstruct an image with an MAE-style model, keeping chosen patches visible.

    The patches ``ids_shuffle[:, :len_keep]`` go through the encoder. The
    decoder fills the remaining positions with ``model.mask_token``, and its
    blocks run with the attention written out explicitly so it can be
    inspected or modified.

    Args:
        orig_image: image accepted by ``convert_to_tensor`` (numpy array or
            tensor; a single channels-last image or a batch).
        model: MAE-style model with ``patch_embed``, ``pos_embed``,
            ``cls_token``, ``blocks``, ``norm``, ``decoder_embed``,
            ``mask_token``, ``decoder_pos_embed``, ``decoder_blocks``,
            ``decoder_norm``, ``decoder_pred`` and ``unpatchify``.
        ids_shuffle: tensor of patch indices, ``[bs, 196]``. It is presumably a
            permutation (e.g. from ``fill_to_full``); its first ``len_keep``
            entries are the visible patches.
        len_keep: number of visible (unmasked) patches.
        device: device to run on; ``model`` is assumed to already live there.

    Returns:
        ``(orig_image, im_paste, mask, y)``, all channels-last (N, H, W, C):
        the input image, the reconstruction with visible patches pasted back,
        the mask (1 = removed, 0 = kept) and the raw reconstruction.
    """
    ids_shuffle = ids_shuffle.to(device)
    orig_image = convert_to_tensor(orig_image).to(device)

    # ---- encoder ----
    # Work on a copy so the returned original image is never aliased by the model.
    latent = model.patch_embed(orig_image.clone().float())
    latent = latent + model.pos_embed[:, 1:, :]  # pos embed w/o cls token
    latent, mask, ids_restore = _mae_keep_shuffled_subset(latent, ids_shuffle, len_keep)

    cls_token = model.cls_token + model.pos_embed[:, :1, :]
    latent = torch.cat((cls_token.expand(latent.shape[0], -1, -1), latent), dim=1)
    for blk in model.blocks:
        latent = blk(latent)
    latent = model.norm(latent)

    # ---- decoder ----
    x = model.decoder_embed(latent)
    # One mask token per removed patch (the +1 accounts for the cls token in x).
    num_masked = ids_restore.shape[1] + 1 - x.shape[1]
    mask_tokens = model.mask_token.repeat(x.shape[0], num_masked, 1)
    x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
    x_ = torch.gather(
        x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
    x = torch.cat([x[:, :1, :], x_], dim=1)  # re-attach cls token
    x = x + model.decoder_pos_embed

    for blk in model.decoder_blocks:
        x = _mae_unrolled_decoder_block(blk, x)
    x = model.decoder_norm(x)
    x = model.decoder_pred(x)[:, 1:, :]  # predictor projection, then drop cls token

    y = model.unpatchify(x)
    y = torch.einsum('nchw->nhwc', y).detach()

    # ---- mask visualization ----
    # Expand each per-patch mask value to all pixel values of its patch (3 channels).
    patch_values = model.patch_embed.patch_size[0] ** 2 * 3
    mask = mask.detach().unsqueeze(-1).repeat(1, 1, patch_values)
    mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping
    mask = torch.einsum('nchw->nhwc', mask).detach()

    orig_image = torch.einsum('nchw->nhwc', orig_image)

    # MAE reconstruction pasted with visible patches
    im_paste = orig_image * (1 - mask) + y * mask
    return orig_image, im_paste, mask, y


# Standard ImageNet per-channel normalization statistics, in R, G, B order
# (the values used by torchvision's pretrained models).
imagenet_mean = np.array((0.485, 0.456, 0.406))
imagenet_std = np.array((0.229, 0.224, 0.225))


def fill_to_full(arr, num_patches: int = 196):
    """Extend a partial patch ordering to cover every index in ``range(num_patches)``.

    The entries of ``arr`` come first, in their given order. Every index in
    ``range(num_patches)`` that does not occur in ``arr`` is then appended in
    ascending order. The (1, N) result matches the ``[bs, 196]`` ``ids_shuffle``
    layout that ``generate_image`` documents.

    Args:
        arr: 1-D sequence of patch indices (list, tuple, numpy array or torch
            tensor). It is not modified.
        num_patches: total number of patches; defaults to 196 (e.g. a 14x14 grid).

    Returns:
        torch tensor of shape (1, len(arr) + number of appended indices), which
        is (1, num_patches) when ``arr`` holds distinct indices from
        ``range(num_patches)``.
    """
    if isinstance(arr, torch.Tensor):
        # Plain Python numbers hash/compare like ints; 0-d tensors would not.
        order = arr.tolist()
    else:
        # A shallow copy suffices: only the list itself is extended below.
        order = list(arr)
    # Set membership keeps this O(n) instead of scanning the list per index.
    present = set(order)
    order.extend(i for i in range(num_patches) if i not in present)
    return torch.tensor(order).unsqueeze(0)



def convert_to_tensor(img):
    """Return ``img`` as a channels-first (N, C, H, W) torch tensor.

    A numpy array is first copied into a new tensor. Any input that is not 4-D
    is treated as a single channels-last (H, W, C) image: a batch axis is added
    and the channel axis is moved to position 1. A 4-D input whose last axis has
    size 3 is treated as channels-last (N, H, W, C) and permuted. Any other 4-D
    input is returned as-is.
    """
    tensor = torch.tensor(img) if isinstance(img, np.ndarray) else img

    if len(tensor.shape) == 4:
        if tensor.shape[-1] != 3:
            # Not recognisably channels-last: assumed already channels-first.
            return tensor
        assert isinstance(tensor, torch.Tensor)
        return tensor.permute(0, 3, 1, 2)

    # Single image: prepend the batch axis, then move channels first.
    return tensor.unsqueeze(dim=0).permute(0, 3, 1, 2)